#!/usr/bin/perl -w

use strict;
use autodie;
use File::Spec;
use File::Basename;
use List::Util;

# Base name (no extension) of the generated C++ source and header files.
my $OFROOT = "funcs";
# Number of random choice tokens placed in each freshly generated spec.
my $GAS = 10;
# When true, print specs and the individual choices while generating.
my $VERBOSE = 0;
# When true, each generated function prints its own name via llvm::outs().
my $DEBUG = 0;
# When true, choose() dies instead of improvising when the spec is exhausted
# or its next token names a different choice site.  The testfile and testspec
# modes switch this on.
my $STRICT = 0;
# Command used to score the compiled functions; the dollar sign is escaped so
# that HOME is not interpolated by Perl here.
my $TESTEXE = "\$HOME/souper-regehr/build/bulk_tests";

# Enables the predicate-driven select form (choice 4) in make_expr().
my $GENBOOLS = 1;
# Makes the entries of @constants available as leaf values.
my $GENCONSTANTS = 1;

# Selects which KnownBits field receives the generated expression: true means
# ret.Zero, false means ret.One.  The other field is set to
# APInt::getMinValue(WIDTH).
my $ZERO = 0;

# Mode flags; exactly one is set to 1 by the command-line dispatch below.
my $RUN = 0;
my $TESTONE = 0;
my $GEN = 0;
my $TESTFILE = 0;

# Operand of the selected mode: file path for "testfile", population/function
# count for "gen" and "run", spec string for "testspec".
my $TESTFN;
my $NFUNCS;
my $TESTSPEC;

# Abort with a synopsis of the four supported sub-commands.
sub usage() {
    my @modes = ("gen N", "testfile FILE.CPP", "run N", "testspec SPEC");
    die "usage: $0 " . join(" | ", @modes);
}

# Exactly two arguments are accepted: a mode word followed by its operand.
usage() unless scalar(@ARGV) == 2;

{
    my ($mode, $operand) = @ARGV;
    if ($mode eq "gen") {
        ($GEN, $NFUNCS) = (1, $operand);
    } elsif ($mode eq "testfile") {
        ($TESTFILE, $TESTFN) = (1, $operand);
        # the file to re-check has to exist already
        die unless -f $TESTFN;
    } elsif ($mode eq "run") {
        ($RUN, $NFUNCS) = (1, $operand);
    } elsif ($mode eq "testspec") {
        ($TESTONE, $TESTSPEC) = (1, $operand);
    } else {
        usage();
    }
}

# need to "export OSTYPE" in .bash_profile; when the variable is not visible
# here the .so name is used
my $ostype = $ENV{"OSTYPE"};
# absolute path of the shared library that is handed to the test executable
my $SHLIB = File::Spec->rel2abs(
    (defined($ostype) && $ostype eq "darwin19") ? "./funcs.dylib" : "./funcs.so"
);

# directory holding this script; compile.sh is looked up there
my $SCRIPTDIR = File::Basename::dirname(__FILE__);

# Constant leaf expressions; generate() adds them to @avail_wide when
# $GENCONSTANTS is set.  WIDTH is the local constant that every generated
# function defines.
my @constants = (
    "APInt::getAllOnesValue(WIDTH)",
    "APInt::getMinValue(WIDTH)",
    "APInt::getSignedMaxValue(WIDTH)",
    "APInt::getSignedMinValue(WIDTH)",
    "APInt(WIDTH, WIDTH)",
    "APInt(WIDTH, 1)",
    );

# Binary operators that make_expr() emits in call form: OP(lhs, rhs).
my @binops1 = (
    "operator&",
    "operator|",
    "operator^",
    "operator+",
    "operator-",
    );

# Binary operators that make_expr() emits in method form: lhs.OP(rhs).
# Names containing "div" or "rem" are additionally guarded against a zero
# right-hand side.
my @binops2 = (
    "operator*",
    "udiv",
    "sdiv",
    "urem",
    "srem",
    );

# Unary operators written as a prefix of the operand.
my @unops1 = (
    "~",
    "-",
    );

# Unary operators written as an argument-less method call on the operand.
my @unops2 = (
    # "byteSwap",
    "reverseBits",
    );

# Methods whose result make_expr() uses as the integer argument of the
# "narrow" operations below, after replacing values >= WIDTH by 0.
my @gen_narrow = (
    "getActiveBits()",
    "getMinSignedBits()",
    "countLeadingZeros()",
    "countLeadingOnes()",
    "getNumSignBits()",
    "countTrailingZeros()",
    "countTrailingOnes()",
    "countPopulation()",
    "logBase2()",
    "ceilLogBase2()",
    );

# One-argument methods that make_expr() calls as a statement on a non-const
# copy of the operand.
my @unary_narrow1 = (
    "setBit",
    "setBitsFrom",
    "setLowBits",
    "setHighBits",
    "clearBit",
    # "clearLowBits",
    "flipBit",
    );

# One-argument methods whose return value make_expr() stores in a new
# temporary.
my @unary_narrow2 = (
    "ashr",
    "lshr",
    "shl",
    "rotl",
    "rotr",
    );

# Not referenced by the generator code in this file.
my @binary_narrow = (
    "setBits",
    );

# Predicates appended directly to an operand (each entry carries its own
# leading dot and parentheses).
my @unary_preds = (
    ".isNegative()",
    ".isNonNegative()",
    ".isStrictlyPositive()",
    ".isAllOnesValue()",
    ".isNullValue()",
    ".isOneValue()",
    ".isMaxSignedValue()",
    ".isMinSignedValue()",
    ".isPowerOf2()",
    ".getBoolValue()",
    ".isMask()",
    ".isShiftedMask()",
    );

# Two-operand predicates, emitted in method form: lhs.PRED(rhs).
my @binary_preds = (
    "operator==",
    "operator!=",
    "slt",
    "sgt",
    "ult",
    "ugt",
    "sle",
    "ule",
    "sge",
    "uge",
    );

# Generator state shared by choose(), make_expr() and friends; generate()
# re-initialises all of it before building a function.
my $choice_in;   # part of the input spec that has not been consumed yet
my $choice_out;  # choices actually taken so far, written in spec syntax
my $empty;       # set by choose() once it has found the input spec exhausted
my @avail_wide;  # C++ expressions and temporaries usable as operands
my $ident;       # number used for the next temporary name (t0, t1, ...)

# Pick a number in [0, $nchoices) for the decision point tagged $site.
#
# The next token of $choice_in (digits followed by one lower-case letter) is
# used when its letter equals $site.  Otherwise, unless $STRICT is set:
#   * an exhausted spec yields 0 and sets $empty;
#   * a token for another site is, with probability one half, dropped and the
#     next token tried, or else left in place while a random value is used;
#   * a malformed leading character is dropped and a random value is used.
# With $STRICT set, exhaustion and site mismatches are fatal.
# The value is reduced into range by repeated subtraction, appended to
# $choice_out as "<value><site>", and returned.
sub choose($$) {
    my ($nchoices, $site) = @_;
    die unless $nchoices > 0;

    my $choice;
    PICK: {
        if ($empty or $choice_in eq "") {
            die if $STRICT;
            $empty = 1;
            $choice = 0;
            last PICK;
        }

        unless ($choice_in =~ /^([0-9]+)([a-z])/) {
            # not a token: discard a single character and improvise
            substr($choice_in, 0, 1) = "";
            $choice = int(rand(100));
            last PICK;
        }

        my ($digits, $tag) = ($1, $2);
        my $token_len = length($digits) + length($tag);

        if ($tag eq $site) {
            substr($choice_in, 0, $token_len) = "";
            $choice = $digits;
            last PICK;
        }

        die "strict!" if $STRICT;
        if (rand() < 0.5) {
            # drop the foreign token and look at the one behind it
            substr($choice_in, 0, $token_len) = "";
            redo PICK;
        }
        # keep the foreign token for a later site and improvise here
        $choice = int(rand(100));
    }

    $choice -= $nchoices while $choice >= $nchoices;
    print "($choice) " if $VERBOSE;
    $choice_out .= "${choice}${site}";
    return $choice;
}

sub make_expr();

# Generate two sub-expressions and return them, in an order decided by
# choice site "b", followed by the concatenated C++ code both of them need:
# (first, second, code).  The code is always emitted in generation order.
sub make_two_exprs() {
    my ($first, $first_code) = make_expr();
    my ($second, $second_code) = make_expr();
    my $code = $first_code . $second_code;
    return choose(2, "b")
        ? ($first, $second, $code)
        : ($second, $first, $code);
}

# Allocate the next temporary name ("t" followed by the current $ident),
# advance the counter, register the name in @avail_wide and return it.
sub new_wide_value() {
    my $fresh = "t" . $ident;
    $ident += 1;
    push @avail_wide, $fresh;
    return $fresh;
}

# Recursively generate one APInt-valued expression.
#
# Returns a two-element list: the C++ expression (an existing operand or the
# name of a freshly created temporary) and the C++ statements that must be
# emitted before it can be used.  Every decision goes through choose(), so
# the shape of the result is determined by the spec being consumed.
sub make_expr() {
    my $c;
    my $CPP = "";
    # site "d" selects the expression form; form 4 exists only with $GENBOOLS
    if ($GENBOOLS) {
        $c = choose(5, "d");
    } else {
        $c = choose(4, "d");
    }

    # this has to be choice zero since the choose operator returns
    # zero when it runs out of gas, and this is the option which lets
    # us escape without creating new stuff
    if ($c == 0) {
        return ($avail_wide[choose(scalar(@avail_wide), "c")], "");
    }

    # form 1: unary operator applied to one sub-expression
    if ($c == 1) {
        (my $a, my $s) = make_expr();
        $CPP .= $s;
        my $name = new_wide_value();
        # a single index (site "e") covers @unops1 followed by @unops2
        $c = choose(scalar(@unops1) + scalar(@unops2), "e");
        if ($c < scalar(@unops1)) {
            $CPP .= "  const APInt $name = ".$unops1[$c].$a.";\n";
        } else {
            $c -= scalar(@unops1);
            $CPP .= "  const APInt $name = ".$a.".".$unops2[$c]."();\n";
        }
        return ($name, $CPP);
    }

    # form 2: binary operator applied to two sub-expressions
    if ($c == 2) {
        (my $l, my $r, my $s) = make_two_exprs();
        $CPP .= $s;
        # a single index (site "f") covers @binops1 followed by @binops2
        my $i = choose(scalar(@binops1) + scalar(@binops2), "f");
        my $name = new_wide_value();
        if ($i < scalar(@binops1)) {
            $CPP .= "  const APInt $name = ".$binops1[$i]."(".$l.", ".$r.");\n";
        } else {
            $i -= scalar(@binops1);
            my $o = $binops2[$i];
            if ($o =~ /div/ || $o =~ /rem/) {
                # a zero right-hand side yields zero instead of dividing
                $CPP .= "  const APInt $name = (${r} == 0) ? APInt(WIDTH, 0) : (".$l.".".${o}."(".$r."));\n";
            } else {
                $CPP .= "  const APInt $name = ".$l.".".${o}."(".$r.");\n";
            }
        }
        return ($name, $CPP);
    }

    # form 3: operation on $a taking an integer argument that is computed
    # from $b by one of the @gen_narrow methods (site "j")
    if ($c == 3) {
        (my $a, my $b, my $s) = make_two_exprs();
        $CPP .= $s;
        my $name = new_wide_value();
        my $narrow = $gen_narrow[choose(scalar(@gen_narrow), "j")];
        # better way to do this?
        # the argument is replaced by 0 whenever it is >= WIDTH
        my $limited = "(${b}.${narrow} >= WIDTH) ? 0 : (${b}.${narrow})";
        # a single index (site "k") covers @unary_narrow1 followed by
        # @unary_narrow2
        my $i = choose(scalar(@unary_narrow1) + scalar(@unary_narrow2), "k");
        if ($i < scalar(@unary_narrow1)) {
            # these methods are called as statements, so work on a non-const copy
            my $op = $unary_narrow1[$i];
            $CPP .= "  APInt ${name} = ${a};\n";
            $CPP .= "  ${name}.${op}($limited);\n";
        } else {
            $i -= scalar(@unary_narrow1);
            my $op = $unary_narrow2[$i];
            $CPP .= "  const APInt ${name} = ${a}.${op}($limited);\n";
        }
        return ($name, $CPP);
    }

    # form 4: select between two sub-expressions based on a predicate
    if ($c == 4) {
        (my $l1, my $r1, my $s1) = make_two_exprs();
        $CPP .= $s1;
        my $pred;
        # site "g" chooses between a binary and a unary predicate
        if (choose(2, "g")) {
            (my $l2, my $r2, my $s2) = make_two_exprs();
            $CPP .= $s2;
            my $p = $binary_preds[choose(scalar(@binary_preds), "h")];
            $pred = $l2.".".$p."(".$r2.")";
        } else {
            (my $a, my $s2) = make_expr();
            $CPP .= $s2;
            my $p = $unary_preds[choose(scalar(@unary_preds), "i")];
            $pred = $a.$p;
        }
        my $name = new_wide_value();
        $CPP .= "  const APInt $name = ($pred) ? (${l1}) : (${r1});\n";
        return ($name, $CPP);
    }

    # choose() never returns a value outside the range requested above
    die;
}

# Build a random spec of $n tokens.  Every token is a random integer in
# 0..99 followed by the letter "a".
sub rand_choices($) {
    my ($n) = @_;
    my @tokens;
    my $remaining = $n;
    while ($remaining > 0) {
        push @tokens, int(rand(100)) . "a";
        $remaining--;
    }
    return join("", @tokens);
}

# Names of all functions generated so far, in generation order; make_footer()
# lists them in the AllFuncs table.
my @funcs;

# Build the C++ source text of one test function.
#
#   $cin    - choice spec that drives the generator
#   $number - index used for the function name (f_000000, f_000001, ...)
#
# Resets the generator state, records the function name in @funcs and returns
# the function text followed by two comment lines holding the input spec and
# the spec of the choices actually taken ($choice_out, which stays available
# to the caller afterwards).
sub generate($$) {
    my ($cin, $number) = @_;

    # start from a clean generator state
    $choice_in  = $cin;
    $choice_out = "";
    $empty      = 0;
    $ident      = 0;
    @avail_wide = ("x.Zero", "x.One", "y.Zero", "y.One");
    push @avail_wide, @constants if $GENCONSTANTS;

    my $name = sprintf("f_%06d", $number);
    push @funcs, $name;

    my @out;
    push @out, "static KnownBits ${name}(const KnownBits &x, const KnownBits &y) {\n";
    push @out, "  llvm::outs() << \"${name}\";\n" if $DEBUG;
    push @out, "  const int WIDTH = x.One.getBitWidth();\n";

    print "$choice_in\n" if $VERBOSE;
    my ($e1, $s1) = make_expr();
    # clear the exhausted flag before the second expression; choose() sets it
    # again if the spec is still empty
    $empty = 0;
    push @out, $s1;

    # the second expression's statements are emitted, but its value ($e2) is
    # only referenced by the disabled line further down
    print "$choice_in\n" if $VERBOSE;
    my ($e2, $s2) = make_expr();
    push @out, $s2;
    print "\n\n\n" if $VERBOSE;

    push @out, "  KnownBits ret(WIDTH);\n";
    my ($target, $other) = $ZERO ? ("Zero", "One") : ("One", "Zero");
    push @out, "  ret.${target} = $e1;\n";
    push @out, "  ret.${other} = APInt::getMinValue(WIDTH);\n";
    #$CPP .= "  ret.One = $e2 & ~$e1;\n";

    push @out, "  return ret;\n";
    push @out, "}\n//  in: ${cin}\n// out: ${choice_out}\n\n";
    return join("", @out);
}

# Return the preamble of the generated C++ file: the LLVM includes, the
# include of the generated header, the using-directive and a blank line.
sub make_header() {
    my @lines = (
        "#include <llvm/ADT/APInt.h>",
        "#include <llvm/Support/KnownBits.h>",
        "#include \"${OFROOT}.h\"",
        "using namespace llvm;",
        "",
    );
    return join("", map { "$_\n" } @lines);
}

# Return (cpp_text, header_text).  The C++ text is the TestFn typedef followed
# by the nullptr-terminated AllFuncs table listing every name in @funcs; the
# header text is the typedef alone.
sub make_footer() {
    my $typedef = "typedef llvm::KnownBits (*TestFn)(const llvm::KnownBits &, const llvm::KnownBits &);\n";

    my $entries = join("", map { "  $_,\n" } @funcs);
    my $cpp = $typedef
        . "TestFn AllFuncs[] = {\n"
        . $entries
        . "  nullptr,\n"
        . "};\n\n";

    return ($cpp, $typedef);
}

# Generate one C++ function per spec and write the generated sources.
#
# Takes a reference to a list of input specs.  Writes "${OFROOT}.cpp" (the
# header text, every generated function, then the AllFuncs table) and
# "${OFROOT}.h" (the TestFn typedef).  Returns a reference to the list of
# specs actually used, i.e. the $choice_out that generate() leaves behind for
# each input, in input order.  Dies if either file cannot be written.
sub go($) {
    my ($lref) = @_;
    my @newspecs;

    my $cpp_path = "${OFROOT}.cpp";
    my $hpp_path = "${OFROOT}.h";
    # three-argument open keeps the mode separate from the file name
    open(my $CPPFILE, '>', $cpp_path) or die "cannot write $cpp_path: $!";
    open(my $HPPFILE, '>', $hpp_path) or die "cannot write $hpp_path: $!";

    print {$CPPFILE} make_header();

    # the position in the input list doubles as the function number
    my $i = 0;
    foreach my $spec (@{$lref}) {
        my $CPP = generate($spec, $i);
        push @newspecs, $choice_out;
        print {$CPPFILE} $CPP;
        $i++;
    }

    my ($footer_cpp, $footer_hpp) = make_footer();
    print {$CPPFILE} $footer_cpp;
    print {$HPPFILE} $footer_hpp;

    close $CPPFILE;
    close $HPPFILE;
    return \@newspecs;
}

# biased towards lower numbers: the smallest of four independent draws from
# the integers below $N/2
sub weighted_index($) {
    my ($N) = @_;
    my @draws = map { int(rand($N/2)) } 1 .. 4;
    return List::Util::min(@draws);
}

# fixme this is rough, needs work
#
# Return a randomly edited copy of the spec $in.  At least one edit is made;
# after each edit another one follows with probability 0.4.  Each edit is one
# of three kinds, picked with roughly equal probability:
#   modify - overwrite one digit with a random digit; if no digit is found
#            within the allowed number of probes, the whole spec is replaced
#            by a fresh random one
#   add    - replace everything from a random position onwards with 0..5
#            fresh random tokens
#   remove - delete up to 11 characters starting at a random position
sub mutate($) {
    my ($in) = @_;
    MUTATION: while (1) {
        my $r = rand();
        if ($r < 0.333) {
            # modify: probe random positions for a digit
            my $idx;
            my $tries;
            for ($tries = 0; ; $tries++) {
                $idx = int(rand(length($in)));
                last if $tries == 100 or substr($in, $idx, 1) =~ /[0-9]/;
            }
            if ($tries == 100) {
                # gave up: start over from a fresh random spec
                $in = rand_choices($GAS);
            } else {
                substr($in, $idx, 1, int(rand(10)));
            }
        } elsif ($r < 0.666) {
            # add
            # NOTE(review): this overwrites the tail of the spec rather than
            # inserting into it -- confirm that this is intended
            my $idx = int(rand(length($in)));
            my $fresh = rand_choices(int(rand(6)));
            substr($in, $idx) = $fresh;
        } else {
            # remove
            my $idx = int(rand(length($in)));
            substr($in, $idx, int(rand(12)), "");
        }
        last MUTATION unless rand() < 0.4;
    }
    return $in;
}

# testfile mode: check that generation is reproducible.  Every "// out: SPEC"
# line of the given file is fed back through generate() in strict mode, and
# the spec reported by the regenerated function has to equal the one read.
if ($TESTFILE) {
    $STRICT = 1;
    # $TESTFILE is only the mode flag; the path from the command line is in
    # $TESTFN
    open(my $INF, '<', $TESTFN) or die "cannot read $TESTFN: $!";
    my $ntests = 0;
    while (my $line = <$INF>) {
        next unless $line =~ /\/\/ out: ([a-zA-Z0-9]+)$/;
        my $spec_in = $1;
        print "<${spec_in}>\n" if $VERBOSE;
        my $func = generate($spec_in, 0);
        $ntests++;
        die "no \"// out:\" line in the code generated for $spec_in"
            unless $func =~ /\/\/ out: ([a-zA-Z0-9]+)$/m;
        my $spec_out = $1;
        print "<${spec_out}>\n\n" if $VERBOSE;
        die "spec $spec_in was regenerated as $spec_out" if ($spec_in ne $spec_out);
    }
    close $INF;
    print "$ntests tests performed\n";
    exit(0);
}

# Generate, compile and score a list of specs.
#
# Takes a reference to a list of input specs.  Regenerates the C++ sources via
# go(), runs compile.sh from the script directory, then runs the test
# executable on the shared library and reads one "score = N" line per
# function, in function order.
#
# Returns a two-element LIST (call it in list context):
#   - reference to a hash mapping each spec returned by go() to
#     1000 * N + length(spec); the run loop treats lower values as better
#   - reference to the list of specs returned by go()
# Specs that occur more than once share a single hash entry.
# Dies if compile.sh fails or the test executable cannot be started.
sub get_scores($) {
    my ($lref) = @_;
    @funcs = ();
    my @newspecs = @{ go($lref) };

    # a failed build must not be scored against a stale shared library
    my $compile = "$SCRIPTDIR/compile.sh";
    system($compile) == 0
        or die "$compile failed (wait status $?)";

    # the command string is still handed to the shell, which expands the
    # HOME reference inside $TESTEXE
    my $cmd = "$TESTEXE $SHLIB";
    # print "$cmd\n";
    open(my $INF, '-|', $cmd) or die "cannot run $cmd: $!";
    my %scores = ();
    my $cnt = 0;
    # TODO add size to score
    while (my $line = <$INF>) {
        if ($line =~ /score = ([0-9]+)$/) {
            $scores{$newspecs[$cnt]} = (1000 * $1) + length($newspecs[$cnt]);
            $cnt++;
        }
    }
    close $INF;
    return (\%scores, \@newspecs);
}

# Splice two specs: a random-length prefix of the first (possibly empty,
# never the whole string) followed by the part of the second that starts at a
# random position.
sub crossover($$) {
    my ($left, $right) = @_;
    my $cut_left = int(rand(length($left)));
    my $cut_right = int(rand(length($right)));
    return substr($left, 0, $cut_left) . substr($right, $cut_right);
}

# run mode: evolve a population of $NFUNCS specs for up to 10000 generations.
#
# Each generation is scored with get_scores() (lower is better) and the next
# one is assembled from: a few unaltered best scorers, mutated good scorers
# (up to about a third of the population), crossovers of good scorers (up to
# about two thirds), and fresh random specs for the remainder.
#
# get_scores() keys its result by the spec actually used, so several inputs
# can collapse into one entry.  All indexing below is therefore bounded by
# the number of distinct scored specs rather than by $NFUNCS.
if ($RUN) {
    my $max_elite = 5;
    my @specs;
    my %xspecs;
    # keep drawing until the population is full, so that a duplicate draw
    # cannot leave it short
    while (scalar(@specs) < $NFUNCS) {
        my $new = rand_choices($GAS);
        if (!$xspecs{$new}) {
            push @specs, $new;
            $xspecs{$new} = 1;
        }
    }
    my $prev_best = 1e19;
    for (my $gen = 0; $gen < 10000; $gen++) {
        die "population holds " . scalar(@specs) . " specs, expected $NFUNCS"
            unless scalar(@specs) == $NFUNCS;
        # print "getting scores\n";
        my ($href, $lref) = get_scores(\@specs);
        # print "got scores\n";
        my %scores = %{$href};
        die "get_scores returned the wrong number of specs"
            unless scalar(@specs) == scalar(@{$lref});

        my @sorted = sort { $scores{$a} <=> $scores{$b} } keys %scores;
        my $ndistinct = scalar(@sorted);
        die "no scores were reported" unless $ndistinct > 0;

        my $total = 0;
        foreach my $k (@sorted) {
            $total += $scores{$k};
            # print "$k $scores{$k}\n";
        }
        # average over the entries that were actually scored
        my $avg = int((0.0 + $total) / $ndistinct);
        my $best = $scores{$sorted[0]};
        print "gen $gen: average = $avg, best = $best\n";

        # the best specs are carried over unaltered, so the best score is
        # not expected to get worse
        die "best score got worse ($best > $prev_best)" if ($best > $prev_best);
        $prev_best = $best;

        @specs = ();
        my %seen = ();
        # elitism -- move a few of the best scorers to the next generation
        # unaltered; never more than exist, and never more than fit
        my $nelite = List::Util::min($max_elite, $ndistinct, $NFUNCS);
        for (my $i = 0; $i < $nelite; $i++) {
            my $spec = $sorted[$i];
            push @specs, $spec;
            die "duplicate elite spec $spec" if ($seen{$spec});
            $seen{$spec} = 1;
            # print "elite: $spec $scores{$spec}\n";
        }
        # mutate some high-scorers
        while (scalar(@specs) < ($NFUNCS * 0.333)) {
            my $idx = weighted_index($ndistinct);
            my $new = mutate($sorted[$idx]);
            if (!$seen{$new}) {
                push @specs, $new;
                $seen{$new} = 1;
            }
        }
        # splice some high-scorers
        while (scalar(@specs) < ($NFUNCS * 0.666)) {
            my $idx1 = weighted_index($ndistinct);
            my $idx2 = weighted_index($ndistinct);
            my $new = crossover($sorted[$idx1], $sorted[$idx2]);
            if (!$seen{$new}) {
                push @specs, $new;
                $seen{$new} = 1;
            }
        }
        # finish up with some fresh random stuff
        while (scalar(@specs) < $NFUNCS) {
            my $new = rand_choices($GAS);
            if (!$seen{$new}) {
                push @specs, $new;
                $seen{$new} = 1;
            }
        }
    }
    exit(0);
}

# gen mode: write $NFUNCS functions built from fresh random specs, then exit.
if ($GEN) {
    my @specs;
    push @specs, rand_choices($GAS) while scalar(@specs) < $NFUNCS;
    go(\@specs);
    exit(0);
}

# testspec mode: generate, compile and score the single spec given on the
# command line in strict mode, and print each scored spec with its score.
if ($TESTONE) {
    $STRICT = 1;
    my @specs = ($TESTSPEC);
    # get_scores() returns the list (\%scores, \@newspecs).  The list
    # assignment picks up the first element; a scalar assignment would receive
    # the last one, which is an array reference.
    my ($href) = get_scores(\@specs);
    foreach my $k (keys %{$href}) {
        print "$k $href->{$k}\n";
    }
    exit(0);
}
